/*
 Copyright 2014-Present Algorithms in Motion LLC
 
 This file is part of FDTD++.
 
 FDTD++ is proprietary software: you can use it and/or modify it
 under the terms of the Algorithms in Motion License as published by
 Algorithms in Motion LLC, either version 1 of the License, or (at your
 option) any later version.
 
 FDTD++ is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 Algorithms in Motion License for more details.
 
 You should have received a copy of the Algorithms in Motion License
 along with FDTD++. If not, see <http://www.aimotionllc.com/licenses/>.
*/
// CREATED    : 9/7/2015
// LAST UPDATE: 9/30/2015

#include "statsxx/machine_learning/restricted_Boltzmann_machine/RBM.hpp"

// STL
#include <algorithm> // std::min(), std::max()
#include <cmath>     // std::sqrt(), std::fabs(), std::round()
#include <limits>    // std::numeric_limits<>::max()
#include <numeric>   // std::iota()
#include <ostream>   // std::ostream
#include <tuple>     // std::tuple<>, std::make_tuple(), std::tie()

// jScience
#include "jScience/linalg.hpp"        // Matrix<>, Vector<>, outer_product()
#include "jScience/math/utility.hpp"  // ngroups()
#include "jScience/stats/data.hpp"    // shuffle_data()
//#include "jScience/stl/ostream.hpp"   // NullStream


// basic algorithm for stochastic gradient descent
// note: momentum is used to help make adjustments along the average direction of movements
// note: I tested the automatic update of learning rate and momentum as described in Timothy Masters' book "DBN in C++ ...", but it seemed to lead to WORSE solutions so was commented (or taken) out -- this is OK, I guess since it was heuristic anyway
//
// note: see the header file for this file (e.g., header.hpp) for more detailed comments on parameters
//
// hardwired features:
//     - the fraction of time each neuron is on (for sparsity) is exponentially smoothed
// Train this RBM via stochastic (mini-batch) gradient descent with momentum,
// an L2 weight penalty, and sparsity penalties on the hidden activations.
//
// Per batch, the log-likelihood gradient is estimated by grad_log_likelihood()
// (K Monte Carlo iterations, optionally mean-field), normalized per case, and
// adjusted by: (i) the weight penalty; (ii) a penalty pushing the exponentially
// smoothed hidden activation q toward sparsity_target; and (iii) a penalty when
// q leaves [q_min, q_max]. K is annealed from K_begin toward K_end at rate K_rate.
//
// Training stops after max_epochs epochs, or earlier when the ratio
// (max weight increment)/(max weight) falls below convg_criterion, or when that
// ratio fails to improve for more than max_no_improvement consecutive epochs.
//
// note: os may be nullptr, in which case no per-epoch error is written
inline void restricted_Boltzmann_machine::RBM::stochastic_gradient(
    const int max_epochs,             // maximum number of epochs
    const int nbatches,               // number of batches per epoch
    // -----
    double lr,                        // learning rate
    double momentum,                  // momentum
    // -----
    const double weight_penalty,      // weight penalty
    // -----
    const double sparsity_target,     // sparsity target
    const double sparsity_penalty,    // sparsity penalty
    const double sparsity_multiplier, // multiplier for q (hidden activation) EMA
    const double q_min,               // minimum (desired) q
    const double q_max,               // maximum (desired) q
    const double q_penalty,           // q penalty
    // -----
    const int K_begin,                // number of starting Monte Carlo (MC) iterations
    const int K_end,                  // number of ending Monte Carlo iterations
    const double K_rate,              // rate of updates to number of MC iterations
    // -----
    const bool mean_field,            // mean-field approximation
    // -----
    const double convg_criterion,     // convergence criterion
    const int max_no_improvement,     // max number of epochs to go without an improvement in convergence
    // -----
    const Matrix<double> &X,          // data (one case per row)
    const bool x_binomial,            // whether the data is binomial
    // -----
    std::ostream* os                  // (optional) output stream; may be nullptr
    )
{
    //=========================================================
    // ERROR CHECKS
    //=========================================================
    
    // note: there are MANY error checks that one could do (e.g., clamping the
    // learning rate into a sane range), but I am just going to leave it open
    
    //=========================================================
    // INITIALIZATION
    //=========================================================
    
    // setup (shuffled) case indices for batches
    std::vector<int> shuffle_idx(X.size(0));
    std::iota(shuffle_idx.begin(), shuffle_idx.end(), 0);
    
    std::vector<int> batch_begin;
    std::vector<int> batch_end;
    std::vector<int> batch_size;
    std::tie(batch_begin, batch_end, batch_size) = ngroups(X.size(0), nbatches);

    // variables for sparsity: exponentially smoothed fraction of time each hidden neuron is on
    Vector<double> q_smoothed(c.size(), 0.0);
    
    // note: the sparsity target must be defined for each hidden neuron
    Vector<double> x_sp_target(c.size(), sparsity_target);
    
    // note: we also need the mean of each input over the entire training set
    Vector<double> x_mean(b.size(), 0.0);
    
    // << jmm: I really need to define the mean() functions, etc. for Vector<> ... then you could call mean() function using each column of X >>
    for(auto i = 0; i < X.size(0); ++i)
    {
        for(auto j = 0; j < X.size(1); ++j)
        {
            x_mean(j) += X(i,j);
        }
    }
    
    x_mean /= X.size(0);
    
    // initialize the number of Monte Carlo iterations (annealed toward K_end each epoch)
    int K = K_begin;
    
    // note: weight update increments are needed from epoch-to-epoch to use momentum
    Matrix<double> W_inc(W.size(0), W.size(1), 0.0);
    Vector<double> b_inc(b.size(), 0.0);
    Vector<double> c_inc(c.size(), 0.0);
    
    // note: the automatic learning-rate/momentum adjustment from Timothy Masters'
    // "DBN in C++ ..." (pp. 138-139) was tested here but led to WORSE solutions,
    // so it was removed (see the note above this function); it lives in version control
    
    // set the best (max_inc/max_weight) ratio to the max it could ever be
    double best_criterion = std::numeric_limits<double>::max();
    
    int nno_improvement = 0;
    
    
    //=========================================================
    // MAIN LOOP
    //=========================================================
    
    //---------------------------------------------------------
    // LOOP OVER EPOCHS
    //---------------------------------------------------------
    
    for(int epoch = 0; epoch < max_epochs; ++epoch)
    {
        // note: every epoch we shuffle the data (the batches) in order to remove any possible serial correlation or accidental groupings of cases
        shuffle_data(shuffle_idx);
        
        double error = 0.0;        
        
        //---------------------------------------------------------
        // LOOP OVER BATCHES
        //---------------------------------------------------------
        
        double max_inc = 0.0; // max |weight increment| seen this epoch; used to test convergence
        
        for(int batch = 0; batch < nbatches; ++batch)
        {
            // it is important to note that all quantities (*_inc, q_smoothed, etc.) are updated every batch ... one may therefore expect things to work best if the number of cases per batch is large enough such that error variances in each quantity (from batch-to-batch) are small 
            
            // ACCUMULATE GRADIENTS, HIDDEN NEURON ACTIVATIONS, AND ERROR (PER CASE) FOR THIS BATCH
            Matrix<double> W_grad(this->W.size(0), this->W.size(1), 0.0);
            Vector<double> b_grad(this->b.size(), 0.0);
            Vector<double> c_grad(this->c.size(), 0.0);
            Vector<double> q_data(this->c.size(), 0.0);
            double error_batch = 0.0;
            
            for(int i = batch_begin[batch]; i <= batch_end[batch]; ++i)
            {
                Matrix<double> Wi_grad;
                Vector<double> bi_grad;
                Vector<double> ci_grad;
                Vector<double> qi_data; // q is used for sparsity
                double error_i;
                std::tie(Wi_grad, bi_grad, ci_grad, qi_data, error_i) = this->grad_log_likelihood(K, mean_field, X.row(shuffle_idx[i]), x_binomial);
                
                W_grad += Wi_grad;
                b_grad += bi_grad;
                c_grad += ci_grad;
                q_data += qi_data;
                error_batch += error_i;
            }
            
            // note: normalization ensures that we can consistently define all penalties, learning rates, etc.
            W_grad /= batch_size[batch];
            b_grad /= batch_size[batch];
            c_grad /= batch_size[batch];
            q_data /= batch_size[batch];
            error_batch /= batch_size[batch];
            
            // PENALIZE LARGE WEIGHTS
            W_grad += weight_penalty*this->W;
            
            // ENCOURAGE SPARSITY 

            // note: the fraction of time each neuron is on is exponentially smoothed (using the standard EMA formula)
            // << jmm: perhaps could update here to use a zero-lag moving average instead of an EMA -- perhaps would be too expensive though >>
            q_smoothed = (q_data - q_smoothed)*sparsity_multiplier + q_smoothed;
            
            Vector<double> x_sp_pen = sparsity_penalty*(q_smoothed - x_sp_target);
            
            // additionally penalize hidden activations outside the desired band [q_min, q_max]
            for(auto i = 0; i < q_data.size(); ++i)
            {
                if(q_data(i) < q_min)
                {
                    x_sp_pen(i) += q_penalty*(q_data(i) - q_min);
                }
                else if(q_data(i) > q_max)
                {
                    x_sp_pen(i) += q_penalty*(q_data(i) - q_max);
                }
            }
            
            W_grad += outer_product(x_sp_pen, x_mean);
            c_grad += x_sp_pen;
            
            // COMPUTE INCREMENTS, AND UPDATE WEIGHTS, BIASES, AND ERROR FOR THIS BATCH
            W_inc = momentum*W_inc - lr*W_grad;
            b_inc = momentum*b_inc - lr*b_grad;
            c_inc = momentum*c_inc - lr*c_grad;
            
            // check also the max W_inc, needed for convergence
            for(auto i = 0; i < W_inc.size(0); ++i)
            {
                for(auto j = 0; j < W_inc.size(1); ++j)
                {
                    double w_inc = std::fabs(W_inc(i,j));
                    if(w_inc > max_inc)
                    {
                        max_inc = w_inc;
                    }
                }
            }
            
            this->W += W_inc;
            this->b += b_inc;
            this->c += c_inc;
            
            error += error_batch;
        }
        
        error /= nbatches;
        
        // note: os is optional -- only report the per-case error if a stream was supplied
        if(os != nullptr)
        {
            *os << error << '\n';
        }
        
        // CHECK FOR CONVERGENCE
        // note: as described in Timothy Masters' "DBN in C++ ..." book, there are two convergence criteria to consider: (i) check the max gradient relative to the max weight, which indicates the worst that randomness can do when one is near an optimum (secondary test); and (ii) quit if after a certain number of epochs we have no further change in the aforementioned ratio, because eventually randomness will provide a very small gradient/weight ratio that will be difficult to further reduce (primary test)
        
        double max_weight = 0.0;
        
        for(auto i = 0; i < this->W.size(0); ++i)
        {
            for(auto j = 0; j < this->W.size(1); ++j)
            {
                double w = std::fabs(this->W(i,j));
                if(w > max_weight)
                {
                    max_weight = w;
                }
            }
        }
        
        // note: guard against division by zero -- an all-zero weight matrix would
        // otherwise yield a NaN/inf ratio that silently disables both tests below
        if(max_weight > 0.0)
        {
            double max_inc_max_weight = max_inc/max_weight;
            
            // convergence criterion #1
            if(max_inc_max_weight < convg_criterion)
            {
                std::cout << "breaking on condition (max_inc_max_weight < convg_criterion)" << '\n';
                break;
            }
            
            if(max_inc_max_weight < best_criterion)
            {
                best_criterion = max_inc_max_weight;
                
                nno_improvement = 0;
            }
            else
            {
                ++nno_improvement;
                
                // convergence criterion #2
                if(nno_improvement > max_no_improvement)
                {
                    std::cout << "breaking on condition (nno_improvement > max_no_improvement)" << '\n';
                    break;
                }
            }
        }
        
        // UPDATE NUMBER OF MC ITERATIONS (exponential approach from K_begin toward K_end)
        K = static_cast<int>(std::round((1.0 - K_rate)*K + K_rate*K_end));
    }
}

